{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "(MR-)MAML.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "jb2pLvUgvXJj",
        "colab_type": "code",
        "outputId": "1031c58a-551b-4575-9c77-bf8e34af5aae",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "# Copyright 2019 Google LLC.\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\")\n",
        "\n",
        "import tensorflow as tf\n",
        "import tensorflow.keras as keras\n",
        "from tensorflow.keras import layers\n",
        "import tensorflow.keras.backend as keras_backend\n",
        "tf.keras.backend.set_floatx('float32')\n",
        "import tensorflow_probability as tfp\n",
        "from tensorflow_probability.python.layers import util as tfp_layers_util\n",
        "\n",
        "import random\n",
        "import sys\n",
        "import time\n",
        "import os\n",
        "\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns\n",
        "\n",
        "print(tf.__version__) # use tensorflow version >= 2.0.0\n",
        "#pip install tensorflow=2.0.0\n",
        "#pip install --upgrade tensorflow-probability\n",
        "\n",
        "exp_type = 'MAML'  # choose from 'MAML', 'MR-MAML-W', 'MR-MAML-A'"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "2.0.0\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fosHnP6GwFZ-",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class SinusoidGenerator():\n",
        "    def __init__(self, K=10, width=5, K_amp=20, phase=0, amps = None, amp_ind=None,  amplitude =None, seed = None):\n",
        "        '''\n",
        "        Args:\n",
        "            K: batch size. Number of values sampled at every batch.\n",
        "            amplitude: Sine wave amplitude.\n",
        "            pahse: Sine wave phase.\n",
        "        '''\n",
        "        self.K = K\n",
        "        self.width = width\n",
        "        self.K_amp = K_amp\n",
        "        self.phase = phase\n",
        "        self.seed = seed\n",
        "        self.x = self._sample_x()\n",
        "        self.amp_ind = amp_ind if amp_ind is not None else random.randint(0,self.K_amp-5)\n",
        "        self.amps = amps if amps is not None else np.linspace(0.1,4,self.K_amp)\n",
        "        self.amplitude = amplitude if amplitude is not None else self.amps[self.amp_ind]\n",
        "\n",
        "    def _sample_x(self):\n",
        "        if self.seed is not None:\n",
        "          np.random.seed(self.seed)\n",
        "        return np.random.uniform(-self.width, self.width, self.K)\n",
        "\n",
        "\n",
        "    def batch(self, noise_scale, x = None):\n",
        "        '''return xa is [K, d_x+d_a], y is [K, d_y]'''\n",
        "        if x is None:\n",
        "          x = self._sample_x()\n",
        "        x = x[:, None]\n",
        "        amp = np.zeros([1, self.K_amp])\n",
        "        amp[0,self.amp_ind] = 1\n",
        "        amp = np.tile(amp, x.shape)\n",
        "        xa = np.concatenate([x, amp], axis = 1)\n",
        "        y = self.amplitude * np.sin(x - self.phase) + np.random.normal(scale = noise_scale, size = x.shape)\n",
        "        return xa, y\n",
        "\n",
        "    def equally_spaced_samples(self, K=None, width=None):\n",
        "        '''Returns K equally spaced samples.'''\n",
        "        if K is None:\n",
        "            K = self.K\n",
        "        if width is None:\n",
        "            width = self.width\n",
        "        return self.batch(noise_scale = 0, x=np.linspace(-width+0.5, width-0.5, K))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Df53q7-VwI2P",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "noise_scale = 0.1 #@param {type:\"number\"}\n",
        "n_obs = 20 #@param {type:\"number\"}\n",
        "n_context = 10 #@param {type:\"number\"}\n",
        "K_amp = 20 #@param {type:\"number\"}\n",
        "x_width = 5 #@param {type:\"number\"}\n",
        "n_iter = 20000 #@param {type:\"number\"}\n",
        "amps = np.linspace(0.1,4,K_amp)\n",
        "lr_inner = 0.01 #@param {type:\"number\"}\n",
        "dim_w = 5 #@param {type:\"number\"}\n",
        "train_ds = [SinusoidGenerator(K=n_context, width = x_width, \\\n",
        "                              K_amp = K_amp, amps = amps) \\\n",
        "                              for _ in range(n_iter)]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jYAoMD0rwQ28",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "\n",
        "class SineModel(keras.Model):\n",
        "    def __init__(self):\n",
        "        super(SineModel, self).__init__() # python 2 syntax\n",
        "        # super().__init__() # python 3 syntax\n",
        "        self.hidden1 = keras.layers.Dense(40)\n",
        "        self.hidden2 = keras.layers.Dense(40)\n",
        "        self.out = keras.layers.Dense(1)\n",
        "\n",
        "    def call(self, x):\n",
        "        x = keras.activations.relu(self.hidden1(x))\n",
        "        x = keras.activations.relu(self.hidden2(x))\n",
        "        x = self.out(x)\n",
        "        return x\n",
        "\n",
        "\n",
        "def kl_qp_gaussian(mu_q, sigma_q, mu_p, sigma_p):\n",
        "  \"\"\"Kullback-Leibler KL(N(mu_q), Diag(sigma_q^2) || N(mu_p), Diag(sigma_p^2))\"\"\"\n",
        "  sigma2_q = tf.square(sigma_q) + 1e-16\n",
        "  sigma2_p = tf.square(sigma_p) + 1e-16\n",
        "  temp = tf.math.log(sigma2_p) - tf.math.log(sigma2_q) - 1.0 + \\\n",
        "          sigma2_q / sigma2_p + tf.square(mu_q - mu_p) / sigma2_p  #n_target * d_w\n",
        "  kl = 0.5 * tf.reduce_mean(temp, axis = 1)\n",
        "  return tf.reduce_mean(kl)\n",
        "\n",
        "def copy_model(model, x=None, input_shape=None):\n",
        "    '''\n",
        "      Copy model weights to a new model.\n",
        "      Args:\n",
        "          model: model to be copied.\n",
        "          x: An input example.\n",
        "    '''\n",
        "    copied_model = SineModel()\n",
        "    if x is not None:\n",
        "      copied_model.call(tf.convert_to_tensor(x))\n",
        "    if input_shape is not None:\n",
        "      copied_model.build(tf.TensorShape([None,input_shape]))\n",
        "    copied_model.set_weights(model.get_weights())\n",
        "    return copied_model\n",
        "\n",
        "def np_to_tensor(list_of_numpy_objs):\n",
        "    return (tf.convert_to_tensor(obj, dtype=tf.float32) for obj in list_of_numpy_objs)\n",
        "\n",
        "\n",
        "def compute_loss(model, xa, y):\n",
        "    y_hat = model.call(xa)\n",
        "    loss = keras_backend.mean(keras.losses.mean_squared_error(y, y_hat))\n",
        "    return loss, y_hat\n",
        "\n",
        "\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4McD728ixTbm",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def train_batch(xa, y, model, optimizer, encoder=None):\n",
        "    tensor_xa, tensor_y = np_to_tensor((xa, y))\n",
        "    if exp_type == 'MAML':\n",
        "      with tf.GradientTape() as tape:\n",
        "          loss, _ = compute_loss(model, tensor_xa, tensor_y)\n",
        "    if exp_type == 'MR-MAML-W':\n",
        "      w = encoder(tensor_xa)\n",
        "      with tf.GradientTape() as tape:\n",
        "          y_hat = model.call(w)\n",
        "          loss = keras_backend.mean(keras.losses.mean_squared_error(tensor_y, y_hat))\n",
        "    if exp_type == 'MR-MAML-A':\n",
        "      _, w, _ = encoder(tensor_xa)\n",
        "      with tf.GradientTape() as tape:\n",
        "          y_hat = model.call(w)\n",
        "          loss = keras_backend.mean(keras.losses.mean_squared_error(y, y_hat))\n",
        "    gradients = tape.gradient(loss, model.trainable_variables)\n",
        "    optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
        "    return loss\n",
        "\n",
        "\n",
        "def test_inner_loop(model, optimizer, xa_context, y_context, xa_target, y_target, num_steps, encoder=None):\n",
        "    inner_record = []\n",
        "    tensor_xa_target, tensor_y_target = np_to_tensor((xa_target, y_target))\n",
        "    if exp_type == 'MAML':\n",
        "      w_target = tensor_xa_target\n",
        "    if exp_type == 'MR-MAML-W':\n",
        "      w_target = encoder(tensor_xa_target)\n",
        "    if exp_type == 'MR-MAML-A':\n",
        "      _, w_target, _ = encoder(tensor_xa_target)\n",
        "\n",
        "    for step in range(0, np.max(num_steps) + 1):\n",
        "        if step in num_steps:\n",
        "          if exp_type == 'MAML':\n",
        "            loss, y_hat = compute_loss(model, w_target, tensor_y_target)\n",
        "          else:\n",
        "            y_hat = model.call(w_target)\n",
        "            loss = keras_backend.mean(keras.losses.mean_squared_error(tensor_y_target, y_hat))\n",
        "          inner_record.append((step, y_hat, loss))\n",
        "        loss = train_batch(xa_context, y_context, model, optimizer, encoder)\n",
        "    return inner_record\n",
        "\n",
        "\n",
        "def eval_sinewave_for_test(model, sinusoid_generator, num_steps=(0, 1, 10), encoder=None, learning_rate = lr_inner, ax = None, legend= False):\n",
        "    # data for training\n",
        "    xa_context, y_context = sinusoid_generator.batch(noise_scale = noise_scale)\n",
        "    y_context = y_context + np.random.normal(scale = noise_scale, size = y_context.shape)\n",
        "    # data for validation\n",
        "    xa_target, y_target = sinusoid_generator.equally_spaced_samples(K = 200, width = 5)\n",
        "    y_target = y_target + np.random.normal(scale = noise_scale, size = y_target.shape)\n",
        "\n",
        "    # copy model so we can use the same model multiple times\n",
        "    if exp_type == 'MAML':\n",
        "      copied_model = copy_model(model, x = xa_context)\n",
        "    else:\n",
        "      copied_model = copy_model(model, input_shape=dim_w)\n",
        "    optimizer = keras.optimizers.SGD(learning_rate=learning_rate)\n",
        "    inner_record = test_inner_loop(copied_model, optimizer, xa_context, y_context, xa_target, y_target, num_steps, encoder)\n",
        "\n",
        "    # plot\n",
        "    if ax is not None:\n",
        "      plt.sca(ax)\n",
        "      x_context = xa_context[:,0,None]\n",
        "      x_target = xa_target[:,0,None]\n",
        "      train, = plt.plot(x_context, y_context, '^')\n",
        "      ground_truth, = plt.plot(x_target, y_target0, linewidth=2.0)\n",
        "      plots = [train, ground_truth]\n",
        "      legends = ['Context Points', 'True Function']\n",
        "      for n, y_hat, loss in inner_record:\n",
        "          cur, = plt.plot(x_target, y_hat[:, 0], '--')\n",
        "          plots.append(cur)\n",
        "          legends.append('After {} Steps'.format(n))\n",
        "      if legend:\n",
        "        plt.legend(plots, legends, loc='center left', bbox_to_anchor=(1, 0.5))\n",
        "      plt.ylim(-6, 6)\n",
        "      plt.axvline(x=-sinusoid_generator.width, linestyle='--')\n",
        "      plt.axvline(x=sinusoid_generator.width,linestyle='--')\n",
        "    return inner_record"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NPbj4ge1KGR4",
        "colab_type": "code",
        "outputId": "02376a7f-90ff-40cf-837a-7205c2e0b85a",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 272
        }
      },
      "source": [
        "exp_type = 'MAML'\n",
        "if exp_type == 'MAML':\n",
        "  model = SineModel()\n",
        "  model.build((None, K_amp+1))\n",
        "\n",
        "  dataset = train_ds\n",
        "  optimizer = keras.optimizers.Adam()\n",
        "  total_loss = 0\n",
        "  n_iter = 15000\n",
        "  losses = []\n",
        "\n",
        "  for i, t in enumerate(random.sample(dataset, n_iter)):\n",
        "      xa_train, y_train = np_to_tensor(t.batch(noise_scale = noise_scale))\n",
        "\n",
        "      with tf.GradientTape(watch_accessed_variables=False) as test_tape:\n",
        "          test_tape.watch(model.trainable_variables)\n",
        "          with tf.GradientTape() as train_tape:\n",
        "              train_loss, _ = compute_loss(model, xa_train, y_train)\n",
        "          model_copy = copy_model(model, xa_train)\n",
        "          gradients_inner = train_tape.gradient(train_loss, model.trainable_variables) # \\nabla_{\\theta}\n",
        "\n",
        "          k = 0\n",
        "          for j in range(len(model_copy.layers)):\n",
        "              model_copy.layers[j].kernel = tf.subtract(model.layers[j].kernel,  # \\phi_t = T(\\theta, \\nabla_{\\theta})\n",
        "                          tf.multiply(lr_inner, gradients_inner[k]))\n",
        "              model_copy.layers[j].bias = tf.subtract(model.layers[j].bias,\n",
        "                          tf.multiply(lr_inner, gradients_inner[k+1]))\n",
        "              k += 2\n",
        "          xa_validation, y_validation = np_to_tensor(t.batch(noise_scale = noise_scale))\n",
        "          test_loss, y_hat = compute_loss(model_copy, xa_validation, y_validation) # test_loss\n",
        "      gradients_outer = test_tape.gradient(test_loss, model.trainable_variables)\n",
        "      optimizer.apply_gradients(zip(gradients_outer, model.trainable_variables))\n",
        "\n",
        "\n",
        "      total_loss += test_loss\n",
        "      loss = total_loss / (i+1.0)\n",
        "      if i % 1000 == 0:\n",
        "          print('Step {}: loss = {}'.format(i, loss))"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Step 0: loss = 0.5836441516876221\n",
            "Step 1000: loss = 0.7144132256507874\n",
            "Step 2000: loss = 0.4053805470466614\n",
            "Step 3000: loss = 0.2834143042564392\n",
            "Step 4000: loss = 0.2200154811143875\n",
            "Step 5000: loss = 0.18113426864147186\n",
            "Step 6000: loss = 0.154592826962471\n",
            "Step 7000: loss = 0.1353120058774948\n",
            "Step 8000: loss = 0.12070710957050323\n",
            "Step 9000: loss = 0.1094995066523552\n",
            "Step 10000: loss = 0.1004406288266182\n",
            "Step 11000: loss = 0.09283847361803055\n",
            "Step 12000: loss = 0.08659099042415619\n",
            "Step 13000: loss = 0.08126141875982285\n",
            "Step 14000: loss = 0.07664626091718674\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KAIcBZspRPEu",
        "colab_type": "code",
        "outputId": "0aff0f48-fd5d-4cd5-bdee-9d3a8c5a7422",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "if exp_type == 'MAML':\n",
        "  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)\n",
        "  n_context = 5\n",
        "  n_test_task = 100\n",
        "  errs = []\n",
        "  for ii in range(n_test_task):\n",
        "    np.random.seed(ii)\n",
        "    A = np.random.uniform(low = amps[0], high = amps[-1])\n",
        "    test_ds = SinusoidGenerator(K=n_context, seed = ii, amplitude = A, amp_ind= random.randint(0,K_amp-5))\n",
        "    inner_record = eval_sinewave_for_test(model,  test_ds, num_steps=(0, 1, 5, 100));\n",
        "    errs.append(inner_record[-1][2].numpy())\n",
        "\n",
        "  print('Model is', exp_type, 'meta-test MSE is', np.mean(errs) )"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Model is MAML meta-test MSE is 0.51450217\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "370SevfOR6D_",
        "colab_type": "text"
      },
      "source": [
        "# Training & Testing for MR-MAML(W)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dr1j1BSVL14X",
        "colab_type": "code",
        "outputId": "e625e62e-0457-4dc6-bfb5-2e933279f179",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 255
        }
      },
      "source": [
        "if exp_type == 'MR-MAML-W':\n",
        "\n",
        "  model = SineModel()\n",
        "  dataset = train_ds\n",
        "  optimizer = keras.optimizers.Adam()\n",
        "\n",
        "  Beta = 5e-5\n",
        "  learning_rate = 1e-3\n",
        "  n_iter = 15000\n",
        "  model.build((None, dim_w))\n",
        "\n",
        "  kernel_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(untransformed_scale_initializer=tf.compat.v1.initializers.random_normal(\n",
        "      mean=-50., stddev=0.1))\n",
        "\n",
        "  encoder_w = tf.keras.Sequential([\n",
        "          tfp.layers.DenseReparameterization(100, activation=tf.nn.relu, kernel_posterior_fn=kernel_posterior_fn,input_shape=(1 + K_amp,)),\n",
        "          tfp.layers.DenseReparameterization(dim_w,kernel_posterior_fn=kernel_posterior_fn),\n",
        "        ])\n",
        "\n",
        "  total_loss = 0\n",
        "  losses = []\n",
        "  start = time.time()\n",
        "\n",
        "  for i, t in enumerate(random.sample(dataset, n_iter)):\n",
        "      xa_train, y_train = np_to_tensor(t.batch(noise_scale = noise_scale))   #[K, 1]\n",
        "\n",
        "      x_validation = np.random.uniform(-x_width, x_width, n_obs - n_context)\n",
        "      xa_validation, y_validation = np_to_tensor(t.batch(noise_scale = noise_scale, x = x_validation))\n",
        "\n",
        "      all_var = encoder_w.trainable_variables + model.trainable_variables\n",
        "      with tf.GradientTape(watch_accessed_variables=False) as test_tape:\n",
        "          test_tape.watch(all_var)\n",
        "          with tf.GradientTape() as train_tape:\n",
        "              w_train = encoder_w(xa_train)\n",
        "              y_hat_train = model.call(w_train)\n",
        "              train_loss =  keras_backend.mean(keras.losses.mean_squared_error(y_train, y_hat_train)) # K*1\n",
        "          gradients_inner = train_tape.gradient(train_loss, model.trainable_variables) # \\nabla_{\\theta}\n",
        "          model_copy = copy_model(model, x = w_train)\n",
        "          k = 0\n",
        "          for j in range(len(model_copy.layers)):\n",
        "              model_copy.layers[j].kernel = tf.subtract(model.layers[j].kernel,  # \\phi_t = T(\\theta, \\nabla_{\\theta})\n",
        "                          tf.multiply(lr_inner, gradients_inner[k]))\n",
        "              model_copy.layers[j].bias = tf.subtract(model.layers[j].bias,\n",
        "                          tf.multiply(lr_inner, gradients_inner[k+1]))\n",
        "              k += 2\n",
        "\n",
        "          w_validation = encoder_w(xa_validation)\n",
        "          y_hat_validation = model_copy.call(w_validation)\n",
        "          mse_loss = keras_backend.mean(keras.losses.mean_squared_error(y_validation, y_hat_validation))\n",
        "          kl_loss = Beta * sum(encoder_w.losses)\n",
        "          validation_loss = mse_loss + kl_loss\n",
        "\n",
        "      gradients_outer = test_tape.gradient(validation_loss,all_var)\n",
        "      keras.optimizers.Adam(learning_rate=learning_rate).apply_gradients(zip(gradients_outer, all_var))\n",
        "\n",
        "      losses.append(validation_loss.numpy())\n",
        "\n",
        "      if i % 1000 == 0 and i > 0:\n",
        "          print('Step {}:'.format(i), 'loss=', np.mean(losses))\n",
        "          losses = []"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Step 1000: loss= 2.6914458\n",
            "Step 2000: loss= 2.4870665\n",
            "Step 3000: loss= 2.4284792\n",
            "Step 4000: loss= 2.3726428\n",
            "Step 5000: loss= 2.3125937\n",
            "Step 6000: loss= 2.228668\n",
            "Step 7000: loss= 2.1762276\n",
            "Step 8000: loss= 2.1387603\n",
            "Step 9000: loss= 2.112448\n",
            "Step 10000: loss= 2.1087198\n",
            "Step 11000: loss= 2.10187\n",
            "Step 12000: loss= 2.102722\n",
            "Step 13000: loss= 2.1002984\n",
            "Step 14000: loss= 2.0911772\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "axs-TERYbQa9",
        "colab_type": "code",
        "outputId": "7a176904-4f8e-4c64-cc1b-83014df3abb9",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "if exp_type == 'MR-MAML-W':\n",
        "  n_context = 5\n",
        "  n_test_task = 100\n",
        "  errs = []\n",
        "  for ii in range(n_test_task):\n",
        "    np.random.seed(ii)\n",
        "    A = np.random.uniform(low = amps[0], high = amps[-1])\n",
        "    test_ds = SinusoidGenerator(K=n_context, seed = ii, amplitude = A, amp_ind= random.randint(0,K_amp-5))\n",
        "    inner_record = eval_sinewave_for_test(model,  test_ds, num_steps=(0, 1, 5, 100), encoder=encoder_w);\n",
        "    errs.append(inner_record[-1][2].numpy())\n",
        "\n",
        "  print('Model is', exp_type, ', meta-test MSE is', np.mean(errs) )"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Model is MR-MAML-W , meta-test MSE is 0.16159104\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fXfX2JvcAATy",
        "colab_type": "text"
      },
      "source": [
        "#Training & Testing for MR-MAML(A)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "W9gwqYmGAACR",
        "colab_type": "code",
        "outputId": "273fa827-d7ed-44eb-b6c8-2f259f4818af",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 170
        }
      },
      "source": [
        "\n",
        "if exp_type == 'MR-MAML-A':\n",
        "  class Encoder(keras.Model):\n",
        "    def __init__(self, dim_w=5, name='encoder', **kwargs):\n",
        "      # super().__init__(name = name)\n",
        "      super(Encoder, self).__init__(name = name)\n",
        "      self.dense_proj = layers.Dense(80, activation='relu')\n",
        "      self.dense_mu = layers.Dense(dim_w)\n",
        "      self.dense_sigma_w = layers.Dense(dim_w)\n",
        "\n",
        "    def call(self, inputs):\n",
        "      h = self.dense_proj(inputs)\n",
        "      mu_w = self.dense_mu(h)\n",
        "      sigma_w = self.dense_sigma_w(h)\n",
        "      sigma_w = tf.nn.softplus(sigma_w)\n",
        "      ws = mu_w + tf.random.normal(tf.shape(mu_w)) * sigma_w\n",
        "      return ws, mu_w, sigma_w\n",
        "\n",
        "  model = SineModel()\n",
        "  model.build((None, dim_w))\n",
        "  encoder_w = Encoder(dim_w = dim_w)\n",
        "  encoder_w.build((None, K_amp+1))\n",
        "  Beta = 5.0\n",
        "  n_iter = 10000\n",
        "  dataset = train_ds\n",
        "  optimizer = keras.optimizers.Adam()\n",
        "  losses = [];\n",
        "\n",
        "  for i, t in enumerate(random.sample(dataset, n_iter)):\n",
        "      xa_train, y_train = np_to_tensor(t.batch(noise_scale = noise_scale))   #[K, 1]\n",
        "\n",
        "      with tf.GradientTape(watch_accessed_variables=False) as test_tape, tf.GradientTape(watch_accessed_variables=False) as encoder_test_tape:\n",
        "          test_tape.watch(model.trainable_variables)\n",
        "          encoder_test_tape.watch(encoder_w.trainable_variables)\n",
        "          with tf.GradientTape() as train_tape:\n",
        "              w_train, _, _ = encoder_w(xa_train)\n",
        "              y_hat = model.call(w_train)\n",
        "              train_loss = keras_backend.mean(keras.losses.mean_squared_error(y_train, y_hat))\n",
        "          model_copy = copy_model(model, x=w_train)\n",
        "          gradients_inner = train_tape.gradient(train_loss, model.trainable_variables) # \\nabla_{\\theta}\n",
        "          k = 0\n",
        "          for j in range(len(model_copy.layers)):\n",
        "              model_copy.layers[j].kernel = tf.subtract(model.layers[j].kernel,  # \\phi_t = T(\\theta, \\nabla_{\\theta})\n",
        "                          tf.multiply(lr_inner, gradients_inner[k]))\n",
        "              model_copy.layers[j].bias = tf.subtract(model.layers[j].bias,\n",
        "                          tf.multiply(lr_inner, gradients_inner[k+1]))\n",
        "              k += 2\n",
        "          x_validation = np.random.uniform(-x_width, x_width, n_obs - n_context)\n",
        "          xa_validation, y_validation = np_to_tensor(t.batch(noise_scale = noise_scale, x = x_validation))\n",
        "\n",
        "          w_validation, w_mu_validation, w_sigma_validation = encoder_w(xa_validation)\n",
        "          test_mse, _ = compute_loss(model_copy, w_validation, y_validation)\n",
        "          kl_ib = kl_qp_gaussian(w_mu_validation, w_sigma_validation,\n",
        "                            tf.zeros(tf.shape(w_mu_validation)), tf.ones(tf.shape(w_sigma_validation)))\n",
        "          test_loss = test_mse + Beta * kl_ib\n",
        "\n",
        "      gradients_outer = test_tape.gradient(test_mse, model.trainable_variables)\n",
        "      optimizer.apply_gradients(zip(gradients_outer, model.trainable_variables))\n",
        "\n",
        "      gradients = encoder_test_tape.gradient(test_loss,encoder_w.trainable_variables)\n",
        "      keras.optimizers.Adam(learning_rate=0.001).apply_gradients(zip(gradients, encoder_w.trainable_variables))\n",
        "\n",
        "      losses.append(test_loss)\n",
        "\n",
        "      if i % 1000 == 0 and i > 0:\n",
        "          print('Step {}:'.format(i), 'loss = ', np.mean(losses))"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Step 1000: loss =  1.9501201\n",
            "Step 2000: loss =  1.9419937\n",
            "Step 3000: loss =  1.8764127\n",
            "Step 4000: loss =  1.8179001\n",
            "Step 5000: loss =  1.7535788\n",
            "Step 6000: loss =  1.6897625\n",
            "Step 7000: loss =  1.632552\n",
            "Step 8000: loss =  1.57314\n",
            "Step 9000: loss =  1.5216928\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "PBPR2byCbaIs",
        "colab_type": "code",
        "outputId": "fa0adae9-446a-470c-88b0-0fbc5b44a61a",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "if exp_type == 'MR-MAML-A':\n",
        "  n_context = 5\n",
        "  n_test_task = 100\n",
        "  errs = []\n",
        "  for ii in range(n_test_task):\n",
        "    np.random.seed(ii)\n",
        "    A = np.random.uniform(low = amps[0], high = amps[-1])\n",
        "    test_ds = SinusoidGenerator(K=n_context, seed = ii, amplitude = A, amp_ind= random.randint(0,K_amp-5))\n",
        "    inner_record = eval_sinewave_for_test(model,  test_ds, num_steps=(0, 1, 5, 100), encoder=encoder_w);\n",
        "    errs.append(inner_record[-1][2].numpy())\n",
        "\n",
        "  print('Model is', exp_type, ', meta-test MSE is', np.mean(errs) )"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Model is MR-MAML-A , meta-test MSE is 0.14966752\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LqOUUaYFdziP",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}